Mxfp8 grouped and multi quantize #598
Open
alextmagro wants to merge 1 commit into
Enable group_quantize for AMD and add multi tensor quantize kernel for mxfp8
ipanfilo
reviewed
May 30, 2026
const int block_id_Y = blockIdx.y;
const int block_id_X = blockIdx.x;
const int dbias_y_offset = blockIdx.y;
#include "rocm_quantize_mxfp8_body.inc"
Collaborator
Can it be an inline function instead?
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_multi_quantize_mxfp8(size_t num_tensors, const NVTETensor *input_list,
                               NVTETensor *output_list, cudaStream_t stream);
Collaborator
There is already nvte_multi_tensor_quantize() here, so it should probably be used instead of adding a separate call.
Introduces grouped and multi quantize kernels for MXFP8. The grouped kernel still needs later optimization; the multi kernel is a stand-in replacement for MoE models.
Improves the quantize kernels by using nontemporal stores.
Moves the ROCm quantize kernel body into a .inc file to avoid repeated code; forceinline causes register spills and hurts performance for the grouped/multi kernels.
Branched off IFU 2.14; the merge target will be updated after the IFU merge.
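For readers unfamiliar with nontemporal stores, here is a minimal host-side C++ sketch of the idea, not the kernel code from this PR. It packs four 8-bit quantized values into one 32-bit word and writes it with `__builtin_nontemporal_store` when the compiler provides that builtin (Clang does), falling back to a plain store otherwise. The names `Fp8x4` and `nt_store_fp8x4` are hypothetical. The sketch uses `memcpy` for the packing step rather than the zero-copy `reinterpret_cast` the PR describes, to stay within strict-aliasing rules in portable host code.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Detect the Clang builtin; other compilers take the plain-store fallback.
#if defined(__has_builtin)
#if __has_builtin(__builtin_nontemporal_store)
#define SKETCH_HAS_NT_STORE 1
#endif
#endif

// Hypothetical stand-in for a small vector of quantized 8-bit outputs.
struct alignas(4) Fp8x4 {
  std::uint8_t v[4];
};

// Write four quantized bytes as one 32-bit word. The nontemporal hint tells
// the hardware the data will not be re-read soon, so it need not stay cached.
inline void nt_store_fp8x4(const Fp8x4& src, std::uint32_t* dst) {
  assert(reinterpret_cast<std::uintptr_t>(dst) % alignof(std::uint32_t) == 0);
  std::uint32_t word;
  std::memcpy(&word, &src, sizeof(word));  // pack bytes without aliasing UB
#ifdef SKETCH_HAS_NT_STORE
  __builtin_nontemporal_store(word, dst);
#else
  *dst = word;  // fallback: ordinary cached store
#endif
}
```

The hint only pays off for write-once output such as quantized tensors that the same kernel never reads back; data that is reused shortly after being written usually benefits from staying in cache.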
Single Kernel Optimization Results: CAST_ONLY, DBIAS_DACT (GELU), Gated SwiGLU (FWD), Gated SwiGLU (BWD / dSwiGLU)
Grouped Quantize Kernel: Rowwise, Colwise, Both (Rowwise + Colwise)
Multi Quantize Kernel: Rowwise, Colwise, Both (Rowwise + Colwise)
Claude change summary
What Changed
Vec::store_to() replaced with NTVec::nt_store() via zero-copy reinterpret_cast for all rowwise output paths.
bulk_tensor_2d_shared_to_global now uses NT stores for all colwise output paths.
[…] rocm_quantize_mxfp8.cuh and rocm_gated_mxfp8.cuh.
Double buffering (in_sh[2], out_colwise_sh[2]) was tested and rejected (+21.7% average regression on CAST_ONLY): without async global→LDS copies (TDM/TMA), the 2x LDS footprint kills occupancy without providing true load/compute overlap. Revisit on MI450 with TDM.
[…] in_sh, colwise pass skips IS_ACT/IS_DACT recomputation and act_in_sh reads. -29.8% average on DBIAS_DACT/both (FP16: -44%, BF16: -31%, FP32: -16%). Zero overhead for non-activation paths (constexpr guard).
[…] (rocm_gated_mxfp8.cuh). Fixed the BF16/FP16 both-mode regression: FWD both BF16 485→413 us (-15%), BWD both BF16 1019→798 us (-22%). All gated both-mode configs now beat B200.
Multi-tensor kernel (nvte_multi_quantize_mxfp8) added for the per-tensor pointer API used by Megatron's split_quantize. Uses a MultiQuantizeMXFP8Args struct with a prefix-sum block_range and a 2D grid (blockIdx.x = col tiles, blockIdx.y = row tiles with binary search). The kernel body is shared with the single/grouped kernels via .inc file inclusion (the compiler refused to inline the 500-line __device__ __forceinline__ body; 5 VGPRs). Balanced distributions: 4.20x rowwise, 3.59x colwise, 2.35x both vs the per-tensor baseline. "Both" mode is slower than grouped (2.35x vs 2.80x) because scattered per-tensor allocations hurt L2 locality, which is inherent to the per-tensor pointer API versus grouped's contiguous buffer.
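The prefix-sum block_range and the binary search over blockIdx.y can be modelled on the host as follows. This is a sketch under stated assumptions, not the layout of MultiQuantizeMXFP8Args, which this summary does not show: `BlockRanges`, `build_block_ranges`, `locate_tile`, and the `tile_rows` parameter are all hypothetical names. The idea is that each tensor owns a contiguous range of row tiles, and a block recovers its tensor by finding the last range start that is less than or equal to its row-tile index.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of the prefix-sum table.
struct BlockRanges {
  // block_range[i] is the first row-tile index owned by tensor i;
  // block_range.back() is the total number of row tiles (the grid's y extent).
  std::vector<int> block_range;
};

// Exclusive prefix sum of per-tensor row-tile counts (rows rounded up to tiles).
inline BlockRanges build_block_ranges(const std::vector<std::size_t>& rows_per_tensor,
                                      std::size_t tile_rows) {
  BlockRanges r;
  r.block_range.push_back(0);
  for (std::size_t rows : rows_per_tensor) {
    const int tiles = static_cast<int>((rows + tile_rows - 1) / tile_rows);
    r.block_range.push_back(r.block_range.back() + tiles);
  }
  return r;
}

struct TileLocation {
  int tensor_id;   // which tensor this row tile belongs to
  int local_tile;  // row-tile index within that tensor
};

// Binary search for the last tensor whose starting tile is <= block_y.
// Tensors with zero rows own an empty range and are skipped automatically.
inline TileLocation locate_tile(const BlockRanges& r, int block_y) {
  assert(block_y >= 0 && block_y < r.block_range.back());
  auto it = std::upper_bound(r.block_range.begin(), r.block_range.end(), block_y);
  const int tensor_id = static_cast<int>(it - r.block_range.begin()) - 1;
  return {tensor_id, block_y - r.block_range[tensor_id]};
}
```

In a kernel, `block_y` would correspond to blockIdx.y and the lookup costs O(log num_tensors) per block, which is why a single launch can cover tensors of uneven sizes without one launch per tensor.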